# Neural Causal Model

import os
import time
from typing import Any, Optional, Tuple

import networkx as nx
from numpy.typing import NDArray
import pandas as pd
import torch
from torch.utils.data import DataLoader
from custom_models.CustomCausalModel import CustomCausalModel

from custom_models.neural_causal.causal_graph import CausalGraph

from custom_models.neural_causal.scm.nn.Simple import Simple_binary, Simple_discrete
from custom_models.neural_causal.scm.nn.continuous import Continuous
from custom_models.neural_causal.scm.ncm import NCM as NCM_model
import numpy as np
from torch.utils.data import Dataset

import pandas as pd


class ContinuousDataset(Dataset):
    """Map-style dataset over a dict of per-variable tensors.

    ``data_dict`` maps variable names to indexable sequences (tensors) of equal
    length. In the default mode each item is a dict holding the ``idx``-th entry
    of every variable. With ``full_ds=True`` the dataset has length 1 and its
    only item is ``data_dict`` itself (the same object, not a copy).
    """

    def __init__(self, data_dict, full_ds=False):
        self.data_dict = data_dict
        self.vars = [name for name in data_dict]
        self.full_ds = full_ds
        if full_ds:
            self.length = 1
        else:
            # The length is taken from the first variable; all variables are
            # expected to have the same number of rows.
            first_var = self.vars[0]
            self.length = len(data_dict[first_var])

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        if not self.full_ds:
            return {name: self.data_dict[name][idx] for name in self.vars}
        return self.data_dict


class NCM(CustomCausalModel):
    """Neural Causal Model wrapper implementing the CustomCausalModel interface.

    The causal structure comes from a networkx digraph. A pair of opposite edges
    (u -> v together with v -> u) is interpreted as one bidirected edge; every
    other edge is a directed edge. ``fit`` trains an ``NCM_model`` by minimising
    its ``biased_nll`` on observational data. ``estimate_effect`` estimates
    ATE/CATE from samples drawn from the trained model.
    """

    def __init__(self, causal_graph: nx.DiGraph):
        self.causal_graph = causal_graph
        return

    def identify_effect(
        self,
        treatment: Optional[dict[str, float]] = None,
        outcome: Optional[dict[str, float]] = None,
        obs_data: Optional[pd.DataFrame] = None,
        int_data: Optional[pd.DataFrame] = None,
        method_params: Optional[dict[str, Any]] = None,
        seed: Optional[int] = None,
        save_dir: Optional[str] = None,
    ) -> Optional[dict[str, Any]]:
        """Stub: identification is not implemented for the NCM; always returns None."""
        return None

    @staticmethod
    def _has_interventions(int_table) -> bool:
        """Return True if ``int_table`` contains at least one non-zero entry.

        ``None`` means "no interventional data". The check looks at the table's
        values; the original ``all(int_table == 0)`` iterated over column labels.
        """
        if int_table is None:
            return False
        return bool(np.asarray(int_table).any())

    def _build_causal_graph(self) -> CausalGraph:
        """Convert ``self.causal_graph`` into a ``CausalGraph``.

        Each pair of opposite edges is kept once, in the orientation first seen,
        as a bidirected edge. All remaining edges become directed edges.
        """
        edges = list(self.causal_graph.edges())
        edge_set = set(edges)  # O(1) membership instead of scanning the edge list
        directed_edges, bidirected_edges = [], []
        seen_pairs = set()
        for u, v in edges:
            if (v, u) in edge_set:
                pair = frozenset((u, v))
                if pair not in seen_pairs:
                    seen_pairs.add(pair)
                    bidirected_edges.append((u, v))
            else:
                directed_edges.append((u, v))

        return CausalGraph(
            V=list(self.causal_graph.nodes()),
            bidirected_edges=bidirected_edges,
            directed_edges=directed_edges,
        )

    @staticmethod
    def _build_functions(cg, data, dis_vars, bin_vars) -> dict:
        """Create one mechanism per variable of ``cg``, chosen by its variable type.

        Binary variables use ``Simple_binary``, discrete variables use
        ``Simple_discrete`` and all others use ``Continuous``. Every observed
        variable and every latent confounder is given size 1.
        """
        v_size = {k: 1 for k in cg}
        u_size = {k: 1 for k in cg.c2}

        functions = {}
        for var in cg.v:
            pa_sizes = {k: v_size[k] for k in cg.pa[var]}
            u_sizes = {k: u_size[k] for k in cg.v2c2[var]}
            if var in bin_vars:
                functions[var] = Simple_binary(
                    pa_sizes, u_sizes, v_size[var], domain=data[var].unique()
                )
            elif var in dis_vars:
                functions[var] = Simple_discrete(
                    pa_sizes, u_sizes, domain=data[var].unique()
                )
            else:
                functions[var] = Continuous(pa_sizes, u_sizes, v_size[var])
        return functions

    def fit(
        self,
        data: Optional[pd.DataFrame] = None,
        int_table: Optional[pd.DataFrame] = None,
        method_params: Optional[dict[str, Any]] = None,
        seed: Optional[int] = None,
        save_dir: Optional[str] = None,
        outcome: Optional[str] = None,
        treatment: Optional[dict[str, float]] = None,
        evidence: Optional[dict[str, float]] = None,
    ) -> dict[str, Any]:
        """
        Train the NCM on observational data and return runtime statistics.

        Args:
            data: observational samples, one column per graph variable.
            int_table: optional intervention table. Any non-zero entry means
                interventional data is present, which is not supported yet.
            method_params: must contain "batch_size" and "num_epochs".
            seed: seed for torch's RNG; the RNG is left untouched when None.
            save_dir: directory that receives "model.pt" and "train_info.csv".
                Nothing is written when None.
            outcome, treatment, evidence: forwarded to ``estimate_effect`` after
                every epoch to track ATE/CATE during training.

        Returns:
            dict with "Training Time", "Avg. Epoch Time" and "Avg. GPU Memory".

        Raises:
            ValueError: if ``data`` is None or empty.
            NotImplementedError: if ``int_table`` contains non-zero entries.
            KeyError: if ``method_params`` lacks a required key.
        """
        method_params = {} if method_params is None else method_params
        treatment = {} if treatment is None else treatment
        evidence = {} if evidence is None else evidence
        if data is None or data.empty:
            raise ValueError("NCM.fit requires a non-empty observational DataFrame")

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if seed is not None:
            torch.manual_seed(seed)

        batch_size = method_params["batch_size"]
        num_epochs = method_params["num_epochs"]
        # When one batch covers the whole dataset, the data is served as a single
        # pre-stacked item (full_ds=True) instead of being collated row by row.
        single_batch_train = batch_size >= data.shape[0]

        if self._has_interventions(int_table):
            # TODO: build the interventional mask and train with
            # self.ncm.biased_nll_with_interventional_data.
            raise NotImplementedError("Not implemented yet")

        data_tensor = torch.as_tensor(
            data.to_numpy(dtype=np.float32), device=device
        )
        cont_vars, dis_vars, bin_vars = self.parse_vars(data)

        # One (N, 1) tensor per column. NOTE(review): assumes the DataFrame
        # column names match the graph's node names — confirm (original TODO).
        data_dict = {
            name: data_tensor[:, i].unsqueeze(1)
            for i, name in enumerate(data.columns)
        }

        train_loader = DataLoader(
            ContinuousDataset(data_dict=data_dict, full_ds=single_batch_train),
            batch_size,
            shuffle=True,
        )

        cg = self._build_causal_graph()
        functions = self._build_functions(cg, data, dis_vars, bin_vars)
        self.ncm = NCM_model(cg, f=functions).to(device)

        optim = torch.optim.AdamW(self.ncm.parameters(), 4e-3)
        lr_schedule = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optim, 50, 1, eta_min=1e-4
        )

        loss_train, epoch_time, ate_train, cate_train, gpu_mem = [], [], [], [], []

        print(f"Training Neural Causal Model on {device}")
        for epoch in range(num_epochs):
            start_epoch_time = time.process_time()
            total_nll = 0.0
            for batch in train_loader:
                if single_batch_train:
                    # full_ds=True: collation added a leading dim of size 1.
                    batch = {var: values[0] for var, values in batch.items()}
                optim.zero_grad()
                # Only sample one u per instance.
                nll = self.ncm.biased_nll(batch, n=1).mean()
                nll.backward()
                optim.step()
                # Accumulate a Python float: summing the tensor would keep every
                # batch's autograd graph alive for the whole epoch.
                total_nll += nll.item()

            lr_schedule.step()

            # Save info and print
            epoch_time.append(time.process_time() - start_epoch_time)

            results, _ = self.estimate_effect(
                outcome=outcome,
                treatment=treatment,
                evidence=evidence,
                method_params=method_params,
                seed=seed,
                save_dir=save_dir,
            )
            ate_train.append(results["ATE"])
            cate_train.append(results["CATE"])
            epoch_loss = total_nll / len(train_loader)
            loss_train.append(epoch_loss)
            gpu_mem.append(
                torch.cuda.max_memory_allocated() if torch.cuda.is_available() else 0
            )

            nll_print = round(epoch_loss, 5)
            epoch_time_print = round(epoch_time[-1], 5)
            current_epoch_str = f"Epoch {epoch} \t NLL: {nll_print} \t Epoch Time: {epoch_time_print} \t ATE: {ate_train[-1]} \t CATE: {cate_train[-1]} \t GPU Mem: {gpu_mem[-1]}"
            print(current_epoch_str)

        print("Training finished")

        if save_dir is not None:
            torch.save(self.ncm.state_dict(), os.path.join(save_dir, "model.pt"))
        delta_train_time = sum(epoch_time)

        # Save per-epoch training statistics
        train_info_df = pd.DataFrame(
            {
                "ATE": ate_train,
                "CATE": cate_train,
                "GPU Memory": gpu_mem,
                "Epoch Time": epoch_time,
                "Loss": loss_train,
            },
            index=range(num_epochs),
        )
        if save_dir is not None:
            train_info_df.to_csv(os.path.join(save_dir, "train_info.csv"))

        runtime = {
            "Training Time": delta_train_time,
            "Avg. Epoch Time": np.average(epoch_time),
            "Avg. GPU Memory": np.average(gpu_mem),
        }
        return runtime

    def estimate_effect(
        self,
        outcome: str,
        treatment: Optional[dict[str, float]] = None,
        evidence: Optional[dict[str, float]] = None,
        method_params: Optional[dict[str, Any]] = None,
        seed: Optional[int] = None,
        save_dir: Optional[str] = None,
        data: Optional[pd.DataFrame] = None,
        int_table: Optional[pd.DataFrame] = None,
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Estimate ATE and CATE of ``treatment`` on ``outcome`` from NCM samples.

        ``method_params``, ``seed``, ``save_dir``, ``data`` and ``int_table``
        are accepted for interface compatibility and are not used here.

        Returns:
            (results, runtime). ``results`` holds distributions, samples and the
            effects. "CATE" and "evidence" are None when ``evidence`` is empty.
            ``runtime`` holds the sampling times for ATE and CATE.
        """
        treatment = {} if treatment is None else treatment
        evidence = {} if evidence is None else evidence

        # Split the treatment specification into the do() assignments for the
        # treated and control arms.
        treatment_dict, control_dict = self.extract_treat_control(treatment)

        ### ATE ###
        ate, delta_estimate_time_ate, int_distr_treated, state_names, treated_samples, int_distr_control, control_samples = (
            self.get_average_effect(
                treatment=treatment_dict,
                control=control_dict,
                outcome=outcome,
                evidence={},
            )
        )

        ### CATE ###
        cate, delta_estimate_time_cate, int_distr_cate, _, conditional_treated_samples, cond_int_distr_control, conditional_control_samples = self.get_average_effect(
            treatment=treatment_dict,
            control=control_dict,
            outcome=outcome,
            evidence=evidence,
        )

        # Save results
        results = {
            "target": outcome,
            "state_names": list(state_names),
            "Interventional Distribution": int_distr_treated,
            "Conditional Interventional Distribution": int_distr_cate,
            "ATE": ate,
            "evidence": evidence if evidence else None,
            "CATE": cate if evidence else None,
            "Interventional Samples": treated_samples,
            "Conditional Interventional Samples": conditional_treated_samples,
            "Control Samples": control_samples,
            "Conditional Control Samples": conditional_control_samples,
            "Control Interventional Distribution": int_distr_control,
            "Conditional Control Distribution": cond_int_distr_control,
        }

        runtime = {
            "Estimation Time ATE": delta_estimate_time_ate,
            "Estimation Time CATE": delta_estimate_time_cate,
        }

        return results, runtime

    def get_average_effect(
        self, treatment, control, outcome, evidence, num_samples: int = 10000
    ) -> tuple[Any, ...]:
        """
        Estimate the average effect of ``treatment`` vs ``control`` on ``outcome``.

        Draws ``num_samples`` samples from the NCM under each do() assignment.
        Both sample sets are passed through ``condition_and_quantize`` with
        ``evidence``. The outcome is then binarized to +1 (> 0) / -1 (<= 0), and
        the effect is mean(treated) - mean(control) of the binarized values.

        Returns:
            (average_effect, sampling_time, treated_distribution, state_names,
             treated_samples_df, control_distribution, control_samples_df).
            The sample DataFrames are the conditioned/quantized samples of all
            variables. ``sampling_time`` covers only the two sampling calls.
        """
        time_start = time.process_time()
        # Gradients are never needed here: the samples are moved to numpy.
        with torch.no_grad():
            control_raw = self.ncm.sampling(num_samples=num_samples, do=control)
            treated_raw = self.ncm.sampling(num_samples=num_samples, do=treatment)
        delta_estimate_time = time.process_time() - time_start

        control_samples = pd.DataFrame(
            {k: v.squeeze().cpu().detach().numpy() for k, v in control_raw.items()}
        )
        treated_samples = pd.DataFrame(
            {k: v.squeeze().cpu().detach().numpy() for k, v in treated_raw.items()}
        )

        # Condition on the evidence and quantize the samples
        treated_samples, control_samples = self.condition_and_quantize(
            treated_samples, control_samples, evidence
        )

        full_treated_samples = treated_samples.copy()
        full_control_samples = control_samples.copy()

        # Binarize the outcome at 0 into {-1, 1} and build its distribution.
        control_outcome = np.where(control_samples[outcome] > 0, 1, -1)
        state_names, control_distr, _ = self.get_probability_distribution(
            control_outcome, [-1, 1]
        )
        treated_outcome = np.where(treated_samples[outcome] > 0, 1, -1)
        _, treated_distr, _ = self.get_probability_distribution(
            treated_outcome, [-1, 1]
        )

        # NOTE: if the evidence filters out every sample, the means are NaN.
        average_effect = np.mean(treated_outcome) - np.mean(control_outcome)

        return average_effect, delta_estimate_time, treated_distr, state_names, full_treated_samples, control_distr, full_control_samples


